import torch
import torch.distributions
from utils.datasets.paths import get_svhn_path
from utils.datasets.svhn import get_SVHN_labels
from utils.datasets.svhn_augmentation import get_SVHN_augmentation
from utils.datasets import TINY_LENGTH
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np

from .svhn_validation_extra_split import SVHNValidationExtraSplit
from .cifar_semi_tiny_partition import BalancedSampler
from .loading_utils import load_teacher_data
from utils.datasets.paths import get_tiny_images_files
from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images
from .svhn_semi_tiny_partition import AllValidSampler

def get_svhn_extra_topk(dataset_classifications_path, teacher_model, samples_per_class, svhn_extra_val_split,
                        class_tpr_min=None, calibrate_temperature=False,
                        id_class_balanced=True, soft_labels=True, batch_size=128,
                        all_sampler=False, augm_type='default', subdivide_epochs=False,
                        num_workers=8,
                        id_config_dict=None, ssl_config=None):
    """Build a DataLoader over SVHN train plus teacher-selected SVHN-extra samples.

    Teacher outputs, per-class confidence thresholds and a softmax temperature are
    obtained from ``load_teacher_data`` and handed to
    ``SVHNPlusSVHNExtraTopKPartition``, which keeps per predicted class the most
    confident extra samples that pass the class threshold.

    Args:
        dataset_classifications_path: forwarded to ``load_teacher_data``.
        teacher_model: forwarded to ``load_teacher_data``.
        samples_per_class: per-class cap for the dataset; with ``all_sampler`` it is
            passed to ``AllValidSampler`` instead.
        svhn_extra_val_split: forwarded to the dataset (selects which extra split is used).
        class_tpr_min: forwarded to ``load_teacher_data``.
        calibrate_temperature: forwarded to ``load_teacher_data``.
        id_class_balanced: only recorded in ``id_config_dict``.
        soft_labels: forwarded to the dataset.
        batch_size: DataLoader batch size.
        all_sampler: use ``AllValidSampler`` (dataset keeps up to 1e8 samples per class)
            instead of ``BalancedSampler``.
        augm_type: augmentation name passed to ``get_SVHN_augmentation``.
        subdivide_epochs: forwarded to ``BalancedSampler``.
        num_workers: DataLoader worker count.
        id_config_dict: optional dict that is filled with a summary of the configuration.
        ssl_config: forwarded to ``load_teacher_data``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the combined dataset.
    """
    # NOTE(review): the first value is named "confidences" upstream, but the dataset
    # applies softmax to it, so it is presumably logits -- TODO confirm.
    teacher_outputs, _, per_class_thresholds, teacher_temperature = load_teacher_data(
        dataset_classifications_path,
        teacher_model,
        class_tpr_min=class_tpr_min,
        od_exclusion_threshold=None,
        calibrate_temperature=calibrate_temperature,
        ssl_config=ssl_config,
    )

    # The same dict object is handed to the augmentation factory and later logged.
    augmentation_info = {}
    train_transform = get_SVHN_augmentation(augm_type, config_dict=augmentation_info)

    # With all_sampler the dataset keeps (practically) every candidate and the
    # per-class count is handed to AllValidSampler instead.
    per_class_cap = 1e8 if all_sampler else samples_per_class

    combined_dataset = SVHNPlusSVHNExtraTopKPartition(
        teacher_outputs,
        samples_per_class=per_class_cap,
        transform_base=train_transform,
        min_conf=per_class_thresholds,
        temperature=teacher_temperature,
        svhn_extra_val_split=svhn_extra_val_split,
        soft_labels=soft_labels,
    )

    sampler = (AllValidSampler(combined_dataset, samples_per_class) if all_sampler
               else BalancedSampler(combined_dataset, subdivide_epochs))

    loader = torch.utils.data.DataLoader(combined_dataset, sampler=sampler,
                                         batch_size=batch_size, num_workers=num_workers)

    if id_config_dict is not None:
        id_config_dict.update({
            'Dataset': 'SVHN-SSL-Extra',
            'Extra validation split': svhn_extra_val_split,
            'Batch out_size': batch_size,
            'Samples per class': samples_per_class,
            'All Sampler': all_sampler,
            'Soft labels': soft_labels,
            'Class balanced': id_class_balanced,
            'Augmentation': augmentation_info,
        })

    return loader


class SVHNPlusSVHNExtraTopKPartition(Dataset):
    """SVHN train split concatenated with a per-class top-k subset of pseudo-labelled SVHN-extra.

    For each class ``i``, the SVHN-extra samples whose teacher prediction is ``i`` and whose
    (untempered) max softmax confidence is at least ``min_conf[i]`` are sorted by that
    confidence. The ``samples_per_class`` most confident ones are kept.

    Flat index layout:
      * ``[0, num_train_samples)``: SVHN train samples, grouped by label in class order.
      * ``[num_train_samples, length)``: selected SVHN-extra samples, grouped by predicted
        class in class order, and within a class by descending confidence.

    Items are ``(image, label)``. With ``soft_labels`` the label is a probability vector
    (one-hot for train samples, temperature-scaled teacher softmax for extra samples);
    otherwise it is an integer class id.
    """

    def __init__(self, model_logits, samples_per_class, transform_base, min_conf,
                 svhn_extra_val_split=True,
                 temperature=1.0, soft_labels=True, preload=True):
        """
        Args:
            model_logits: teacher outputs with one row per SVHN-extra sample, softmaxed over
                dim 1 (presumably shape (N, num_classes) -- TODO confirm against caller).
            samples_per_class: maximum number of extra samples kept per predicted class.
            transform_base: transform applied to both the train and the extra images.
            min_conf: per-class confidence thresholds, indexed by class id.
            svhn_extra_val_split: if True use the 'extra-split' of SVHNValidationExtraSplit,
                otherwise the full torchvision SVHN 'extra' split.
            temperature: softmax temperature for soft pseudo labels (not used for selection).
            soft_labels: return probability-vector labels instead of integer class ids.
            preload: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``model_logits`` does not have one row per SVHN-extra sample.
        """
        self.samples_per_class = samples_per_class
        self.soft_labels = soft_labels
        self.temperature = temperature

        svhn_path = get_svhn_path()
        if svhn_extra_val_split:
            self.svhn_extra = SVHNValidationExtraSplit(svhn_path, split='extra-split', transform=transform_base)
        else:
            self.svhn_extra = datasets.SVHN(svhn_path, split='extra', transform=transform_base)

        if len(model_logits) != len(self.svhn_extra):
            raise ValueError(f'Expected one row of teacher logits per SVHN-extra sample: '
                             f'got {len(model_logits)} rows for {len(self.svhn_extra)} samples')
        self.model_logits = model_logits
        # Selection uses the untempered softmax; `temperature` only shapes the soft labels.
        predicted_max_conf, predicted_class = torch.max(torch.softmax(self.model_logits, dim=1), dim=1)

        class_labels = get_SVHN_labels()
        self.num_classes = len(class_labels)
        self.train_dataset = datasets.SVHN(svhn_path, split='train', transform=transform_base)
        self.num_train_samples = len(self.train_dataset)

        # Group train indices by label. flatten() rather than squeeze(): squeeze() would
        # collapse a single hit into a 0-d tensor on which len() raises.
        train_targets = torch.as_tensor(self.train_dataset.labels, dtype=torch.long)
        self.train_class_idcs = []
        self.train_per_class = torch.zeros(self.num_classes, dtype=torch.long)
        for i in range(self.num_classes):
            train_i = torch.nonzero(train_targets == i, as_tuple=False).flatten()
            self.train_class_idcs.append(train_i)
            self.train_per_class[i] = len(train_i)

        # Per class: all threshold-passing candidates (valid_indices) and the top-k of them
        # by confidence (in_use_indices), both as 1-D index tensors into SVHN-extra.
        self.in_use_indices = []
        self.valid_indices = []
        self.semi_per_class = torch.zeros(self.num_classes, dtype=torch.long)
        for i in range(self.num_classes):
            class_mask = (predicted_class == i) & (predicted_max_conf >= min_conf[i])
            candidate_idcs = torch.nonzero(class_mask, as_tuple=False).flatten()
            # Boolean-mask selection keeps the same ascending-index order as nonzero().
            candidate_conf = predicted_max_conf[class_mask]
            by_confidence = torch.argsort(candidate_conf, descending=True)

            num_samples_i = int(min(samples_per_class, len(candidate_idcs)))
            class_i_idcs = candidate_idcs[by_confidence[:num_samples_i]]

            self.valid_indices.append(candidate_idcs)
            self.in_use_indices.append(class_i_idcs)
            self.semi_per_class[i] = len(class_i_idcs)

            if num_samples_i < samples_per_class:
                print(f'Incomplete class {class_labels[i]} - Target count: {samples_per_class} - Found samples {len(class_i_idcs)}')

        semi_counts = self.semi_per_class.tolist()
        min_samples_per_class = min(semi_counts, default=0)
        max_samples_per_class = max(semi_counts, default=0)

        # Kept as 0-d tensors (as before) in case other components rely on that type.
        self.num_semi_samples = self.semi_per_class.sum()
        self.length = self.num_train_samples + self.num_semi_samples

        # Per-class [start, end) ranges in the flat index space. Rebinding (not `+=`) is
        # required: an in-place add on a tensor start would mutate tuples already stored.
        self.train_idx_ranges = []
        self.semi_idx_ranges = []
        train_start, semi_start = 0, self.num_train_samples
        for n_train, n_semi in zip(self.train_per_class, self.semi_per_class):
            train_end = train_start + n_train
            semi_end = semi_start + n_semi
            self.train_idx_ranges.append((train_start, train_end))
            self.semi_idx_ranges.append((semi_start, semi_end))
            train_start, semi_start = train_end, semi_end

        self.cum_train_lengths = torch.cumsum(self.train_per_class, dim=0)
        self.cum_semi_lengths = torch.cumsum(self.semi_per_class, dim=0)

        print(f'Top K -  Temperature {self.temperature} - Soft labels {soft_labels}'
              f'  -  Target Samples per class {self.samples_per_class} - Train Samples {self.num_train_samples}')
        print(f'Min Semi Samples {min_samples_per_class} - Max Semi samples {max_samples_per_class}'
              f' - Total semi samples {int(self.num_semi_samples)} - Total length {int(self.length)}')

    def get_used_semi_indices(self, verbose_exclude=False):
        """Return SVHN-extra indices as one 1-D tensor, concatenated in class order.

        Args:
            verbose_exclude: if True, return every sample that passed its class confidence
                threshold, including those outside the top-k; otherwise only the top-k
                samples actually served by this dataset.
        """
        if verbose_exclude:
            return torch.cat(self.valid_indices)
        return torch.cat(self.in_use_indices)

    def _load_train_image(self, train_idx):
        """Load SVHN train sample ``train_idx``; one-hot encode its label when soft_labels is set."""
        img, label = self.train_dataset[train_idx]
        if not self.soft_labels:
            return img, label
        one_hot_label = torch.zeros(self.num_classes)
        one_hot_label[label] = 1.0
        return img, one_hot_label

    def _load_svhn_extra_image(self, class_idx, extra_lin_idx):
        """Load the ``extra_lin_idx``-th selected extra sample of class ``class_idx``.

        The label is the temperature-scaled teacher softmax when soft_labels is set,
        otherwise the teacher's argmax class.
        """
        valid_index = self.in_use_indices[class_idx][extra_lin_idx].item()
        img, _ = self.svhn_extra[valid_index]

        if self.soft_labels:
            label = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            label = torch.argmax(self.model_logits[valid_index, :]).item()
        return img, label

    @staticmethod
    def _locate(cum_lengths, offset):
        """Map a flat offset to ``(class index, position within class)`` as Python ints.

        Raises IndexError when ``offset`` is past the last class.
        """
        class_idx = int(torch.nonzero(cum_lengths > offset, as_tuple=False)[0])
        class_start = int(cum_lengths[class_idx - 1]) if class_idx > 0 else 0
        return class_idx, int(offset - class_start)

    def __getitem__(self, index):
        if index < self.num_train_samples:
            class_idx, sample_idx = self._locate(self.cum_train_lengths, index)
            return self._load_train_image(int(self.train_class_idcs[class_idx][sample_idx]))

        class_idx, sample_idx = self._locate(self.cum_semi_lengths, index - self.num_train_samples)
        return self._load_svhn_extra_image(class_idx, sample_idx)

    def __len__(self):
        return self.length
